import logging
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import dates
from pandas import Timestamp

from market_simulation.agents.noise_agent import NoiseAgent
from market_simulation.states.trade_info_state import TradeInfoState
from mlib.core.env import Env
from mlib.core.event import create_exchange_events
from mlib.core.exchange import Exchange
from mlib.core.exchange_config import create_exchange_config_without_call_auction
from mlib.core.trade_info import TradeInfo


def run_basic_simulation(seed: int = 0, output_path: Path = Path("tmp/price_curves.png")) -> None:
    """Run a basic market simulation with a single noise agent and plot the result.

    Builds an exchange (no call auction) for one symbol, drives a ``NoiseAgent``
    through the session 2024-01-01 09:30:00 to 10:30:00, collects the trade
    information recorded by ``TradeInfoState``, and saves a price-trajectory plot.

    Args:
        seed: Random seed forwarded to the ``NoiseAgent`` for reproducible runs.
        output_path: File the price-trajectory image is written to. Defaults to
            ``tmp/price_curves.png``, the location previously hard-coded here.

    Returns:
        None
    """
    # Simulation parameters: a single symbol traded over a one-hour session.
    symbols = ["000000"]
    start_time = Timestamp("2024-01-01 09:30:00")
    end_time = Timestamp("2024-01-01 10:30:00")

    # Exchange environment without a call auction phase.
    exchange_config = create_exchange_config_without_call_auction(
        market_open=start_time,
        market_close=end_time,
        symbols=symbols,
    )
    exchange = Exchange(exchange_config)

    # Noise agent that produces the order flow for the only symbol.
    agent = NoiseAgent(
        symbol=symbols[0],
        init_price=100000,
        interval_seconds=1,
        start_time=start_time,
        end_time=end_time,
        seed=seed,
    )

    # Register the state that records trades, then wire the agent and the
    # exchange's scheduled events into the environment.
    exchange.register_state(TradeInfoState())
    env = Env(exchange=exchange, description="Noise agent simulation")
    env.register_agent(agent)
    env.push_events(create_exchange_events(exchange_config))

    # Main loop: each observation names the agent that must act next.
    for observation in env.env():
        action = observation.agent.get_action(observation)
        env.step(action)

    # Extract the recorded trades and plot them.
    trade_infos: list[TradeInfo] = extract_trade_information(exchange, symbols[0], start_time, end_time)
    logging.info("Collected %d trade information records.", len(trade_infos))
    visualize_price_trajectory(trade_infos, output_path)


def extract_trade_information(exchange: Exchange, symbol: str, start_time: Timestamp, end_time: Timestamp) -> list[TradeInfo]:
    """Extract trade information for one symbol from a completed simulation.

    Looks up the ``TradeInfoState`` registered on ``exchange`` for ``symbol``
    and keeps only the records whose order time lies in the closed interval
    ``[start_time, end_time]`` (both ends inclusive).

    Args:
        exchange: The exchange instance containing simulation states.
        symbol: Market symbol to extract data for.
        start_time: Inclusive beginning of the time range.
        end_time: Inclusive end of the time range.

    Returns:
        list[TradeInfo]: A new list of the matching records, in the order the
        state stores them.

    Raises:
        TypeError: If the state registered under ``TradeInfoState``'s name is
            not a ``TradeInfoState`` instance.
        KeyError: Presumably, if ``symbol`` or the state name is not registered
            (assumes ``exchange.states()`` returns nested dicts — TODO confirm).
    """
    state = exchange.states()[symbol][TradeInfoState.__name__]
    # Explicit check instead of ``assert``: asserts are removed under ``python -O``,
    # which would let a wrong state type slip through silently.
    if not isinstance(state, TradeInfoState):
        raise TypeError(
            f"Expected {TradeInfoState.__name__} for symbol {symbol!r}, got {type(state).__name__}"
        )
    return [x for x in state.trade_infos if start_time <= x.order.time <= end_time]


def visualize_price_trajectory(trade_infos: list[TradeInfo], path: Path, freq: str = "1min") -> None:
    """Visualize the price trajectory from simulation results.

    Takes ``lob_snapshot.last_price`` at each record's ``order.time``, drops
    non-positive prices, averages the rest per ``freq`` time bucket and saves
    a line plot to ``path``. If no record has a positive last price, nothing
    is written and a warning is logged.

    Args:
        trade_infos: List of TradeInfo objects containing price data.
        path: Location to save the generated visualization. Parent directories
            are created as needed.
        freq: pandas offset alias used to bucket the data before averaging.
            Defaults to ``"1min"`` (per-minute aggregation).

    Returns:
        None
    """
    # Keep only records with a positive last price (non-positive values are
    # presumably "no trade yet" placeholders — TODO confirm).
    prices = [
        {"Time": x.order.time, "Price": x.lob_snapshot.last_price}
        for x in trade_infos
        if x.lob_snapshot.last_price > 0
    ]
    if not prices:
        # An empty DataFrame has no "Time" column, so the Grouper below would
        # raise KeyError. There is nothing to plot, so skip instead.
        logging.warning("No trades with a positive last price; skipping plot %s", path)
        return

    # Create output directory if it doesn't exist.
    path.parent.mkdir(parents=True, exist_ok=True)

    # Average prices within each time bucket to smooth the curve.
    price_data = pd.DataFrame(prices).groupby(pd.Grouper(key="Time", freq=freq)).mean().reset_index()

    sns.set_style("darkgrid")
    fig, ax = plt.subplots(figsize=(5, 3))
    try:
        sns.lineplot(x="Time", y="Price", data=price_data, ax=ax)
        # Show the time axis as hour:minute.
        ax.xaxis.set_major_formatter(dates.DateFormatter("%H:%M"))
        ax.set_title("Price Trajectory Generated by NoiseAgent")
        fig.tight_layout()
        fig.savefig(str(path))
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open matplotlib figures.
        plt.close(fig)
    logging.info("Saved price trajectory visualization to %s", path)


if __name__ == "__main__":
    # The root logger defaults to WARNING, which would hide the logging.info
    # progress messages this script emits; enable INFO output for script runs.
    logging.basicConfig(level=logging.INFO)
    run_basic_simulation(seed=2)
